import sys
from pathlib import Path

# Absolute directory containing this file; used below as the base for the
# "modules" and "data_csn" sub-directories.
curpath = Path(__file__).parent.resolve()
# Prepend the sibling "modules" directory to the import search path
# (presumably where the `tree_parsing` package imported further down lives).
module_path = str(curpath.joinpath("./modules"))
sys.path.insert(0, module_path)

import os
import glob
import datasets
import numpy as np
import random
import re

from tree_parsing import treeparsing_utils as parseutils


# Default worker-process count handed to `ds.map(num_proc=...)` by the task
# builders in this file; `get_dataset(nproc=...)` overwrites it at runtime.
NPROC = 32
# DOC_KEY = "func_documentation_string"
# CODE_KEY = "func_code_string"
# these keys are in the original dataset, while the above keys are in the huggingface version
# Column with the raw docstring text; it is stripped out of Python code samples
# before they are parsed.
RAW_DOCSTRING_KEY = "docstring"
# Column with the docstring as a list of tokens; joined with spaces to form the
# natural-language side of a sample.
DOC_KEY = "docstring_tokens"
# Column with the function source code that gets parsed into a tree.
CODE_KEY = "function"
# Column with the language name; forwarded to the parser and used in prompts.
LANG_KEY = "language"
# Traversal name passed to `parseutils.serialize` when flattening a tree.
TRAVERSAL_TYPE = "preorder_dfs_nodeleaf_toks"
# Tokens starting with NODE_START_STR / ending with NODE_END_STR are treated as
# the open / close markers of a tree node by the masking tasks.
NODE_START_STR = "(_."
NODE_END_STR = "._)"
# Only referenced from commented-out code in the visible part of this file.
KEEPCOLS = ["language", "function", "url"]


def get_data_files(lang, randomize):
    """Collect the ``*.jsonl.gz`` shard paths of every split for ``lang``.

    Shards are globbed under ``<dir of this file>/data_csn/<language>/final/jsonl/<split>/``.

    Args:
        lang: A language directory name, or ``"all"`` to gather the shards of
            every language listed in ``LANGS`` (in that order).
        randomize: When truthy, each split's file list is reordered with
            ``np.random.permutation``.

    Returns:
        Dict mapping ``"train"``, ``"valid"`` and ``"test"`` to a list of paths.
    """
    LANGS = ["go", "java", "javascript", "ruby", "php", "python"]
    basepath = curpath.joinpath("./data_csn")
    root = str(basepath)

    # A single language is handled as a one-element language list.
    wanted_langs = LANGS if lang == "all" else [lang]

    paths = {}
    for split in ["train", "valid", "test"]:
        found = []
        for currlang in wanted_langs:
            pattern = os.path.join(root, f"{currlang}/final/jsonl/{split}/*.jsonl.gz")
            found.extend(glob.glob(pattern))
        paths[split] = found

    if randomize:
        paths = {split: list(np.random.permutation(files)) for split, files in paths.items()}

    return paths


def load_codesearchnet_data(lang):
    """Load the merged CodeSearchNet train/valid JSONL files.

    Args:
        lang: Must be ``"all"``; no other value is supported.

    Returns:
        A ``datasets.DatasetDict`` with the keys ``"train"`` and ``"valid"``.

    Raises:
        AssertionError: If ``lang`` is anything other than ``"all"``.
    """
    if lang != "all":
        raise AssertionError(f"{lang} codesearchnet data load not supported")

    # Paths are relative to the current working directory.
    datapaths = {
        "train": "./data_csn/codesearchnet/codesearchnet_train.jsonl",
        "valid": "./data_csn/codesearchnet/codesearchnet_valid.jsonl",
    }
    print(datapaths)

    ds = datasets.DatasetDict()
    for split_name, jsonl_path in datapaths.items():
        # Each file is loaded on its own, so its rows land in the "train" split
        # of the JSON loader regardless of which split it represents here.
        ds[split_name] = datasets.load_dataset("json", data_files=jsonl_path, split="train")

    print(ds)
    return ds


def clean_doc_from_csn_code_samples(code, doc, lang):
    """Remove the docstring ``doc`` from a Python ``code`` sample.

    Non-Python samples, empty docstrings and docstrings that do not occur in
    ``code`` leave ``code`` untouched.

    Args:
        code: Source code of the function.
        doc: Docstring text (without the surrounding quotes).
        lang: Language name of the sample; only ``"python"`` is cleaned.

    Returns:
        ``code`` with the docstring removed where possible.
    """
    if lang != "python":
        return code

    # check if doc actually exists in the code
    if not doc or doc not in code:
        return code

    # Preferred path: delete the docstring *literal* (optional r/u prefix,
    # matching triple quotes, any surrounding whitespace). Replacing the bare
    # text everywhere would also delete unrelated occurrences of the same text,
    # e.g. a docstring "run" inside `def run()` or `run_it()`.
    quoted_doc = r'(?:\b[rRuU])?("{3}|\'{3})\s*' + re.escape(doc) + r"\s*\1"
    cleaned, n_removed = re.subn(quoted_doc, "", code, count=1)
    if n_removed:
        return cleaned

    # Fallback (legacy behaviour) for docstrings that are not a clean
    # triple-quoted literal: drop the text wherever it occurs, then remove any
    # triple-double-quoted literal left holding only spaces/newlines.
    code = code.replace(doc, "")
    re_emptydoc = r'("""[ \n]*""")'
    return re.sub(re_emptydoc, "", code)


# Define all datasets
# TASK: codesearchnet natural language to tree task
def codesearchnet_nl2tree():
    """Build the natural-language -> tree task from the CodeSearchNet data.

    Every sample with a non-empty docstring whose code can be parsed yields
    the input ``"en <docstring>"`` and the output ``"<lang> tree <tree>"``.

    Returns:
        The mapped dataset with the columns ``"inputs"`` and ``"outputs"``.
    """

    def tree_nl2pl(examples):
        # Batched map function: rows may be dropped, so fresh columns are built.
        inputs, outputs = [], []
        rows = zip(
            examples[DOC_KEY],
            examples[LANG_KEY],
            examples[CODE_KEY],
            examples[RAW_DOCSTRING_KEY],
        )

        for doc_tokens, lang, code, rawdoc in rows:
            # Without a docstring there is no natural-language input.
            if len(doc_tokens) == 0:
                continue

            stripped_code = clean_doc_from_csn_code_samples(code, rawdoc, lang)
            treestr = parsecode(stripped_code, lang, add_cls=True)
            if treestr is None:
                continue

            doc = " ".join(doc_tokens)
            inputs.append(f"en {doc}")
            outputs.append(f"{lang} tree {treestr}")

        return {"inputs": inputs, "outputs": outputs}

    ds = load_codesearchnet_data("all")
    # Drop every original column; only "inputs"/"outputs" remain afterwards.
    original_columns = list(ds["train"].features.keys())
    return ds.map(tree_nl2pl, batched=True, remove_columns=original_columns, num_proc=NPROC)


# TASK: codesearchnet tree to natural language task
def codesearchnet_tree2nl():
    """Build the tree -> natural-language task from the CodeSearchNet data.

    Every sample with a non-empty docstring whose code can be parsed yields
    the input ``"<lang> tree <tree>"`` and the output ``"en <docstring>"``.

    Returns:
        The mapped dataset with the columns ``"inputs"`` and ``"outputs"``.
    """

    def tree_pl2nl(examples):
        # Batched map function: unparsable or undocumented rows are skipped.
        pairs = []
        columns = (
            examples[DOC_KEY],
            examples[LANG_KEY],
            examples[CODE_KEY],
            examples[RAW_DOCSTRING_KEY],
        )

        for doc_tokens, lang, code, rawdoc in zip(*columns):
            if len(doc_tokens) == 0:
                continue

            treestr = parsecode(
                clean_doc_from_csn_code_samples(code, rawdoc, lang),
                lang,
                add_cls=True,
            )
            if treestr is None:
                continue

            doc = " ".join(doc_tokens)
            pairs.append((f"{lang} tree {treestr}", f"en {doc}"))

        return {
            "inputs": [source for source, _ in pairs],
            "outputs": [target for _, target in pairs],
        }

    ds = load_codesearchnet_data("all")
    original_columns = list(ds["train"].features.keys())
    return ds.map(tree_pl2nl, batched=True, remove_columns=original_columns, num_proc=NPROC)


# TASK: Masked Node Prediction on non-terminal nodes
def codesearchnet_masked_node_pred(tokenizer):
    """Build the masked node prediction task from the CodeSearchNet data.

    In each serialized tree, the space-separated tokens that start with
    ``NODE_START_STR`` or end with ``NODE_END_STR`` are replaced by sentinel
    tokens; the target lists every sentinel followed by the token it hides.

    Args:
        tokenizer: Object exposing ``additional_special_tokens``; the sorted
            list of those tokens is used as the pool of sentinels.

    Returns:
        The mapped dataset with the columns ``"inputs"`` and ``"outputs"``.
    """

    def create_ip_op_pair(tstr, mask_tokens):
        # Distinct structure tokens get distinct sentinels (in order of first
        # appearance); repeats reuse their sentinel. Once the sentinel pool is
        # used up, further unseen structure tokens stay unmasked.
        sentinel_for = {}
        masked_input, targets = [], []

        for node in tstr.split(" "):
            is_structural = node.lstrip().startswith(NODE_START_STR) or node.rstrip().endswith(NODE_END_STR)
            if not is_structural:
                masked_input.append(node)
                continue

            if node in sentinel_for:
                masked_input.append(sentinel_for[node])
                continue

            # One sentinel is consumed per mapped token, so the map size is
            # also the index of the next unused sentinel.
            next_free = len(sentinel_for)
            if next_free >= len(mask_tokens):
                masked_input.append(node)
                continue

            sentinel = mask_tokens[next_free]
            sentinel_for[node] = sentinel
            masked_input.append(sentinel)
            targets.extend([sentinel, node])

        return " ".join(masked_input), " ".join(targets)

    def masked_node_pred(examples, mask_tokens):
        # Batched map function; rows whose code cannot be parsed are skipped.
        inputs, outputs = [], []
        rows = zip(examples[LANG_KEY], examples[CODE_KEY], examples[RAW_DOCSTRING_KEY])

        for lang, code, rawdoc in rows:
            stripped_code = clean_doc_from_csn_code_samples(code, rawdoc, lang)
            treestr = parsecode(stripped_code, lang, add_cls=True)
            if treestr is None:
                continue

            ip, op = create_ip_op_pair(treestr, mask_tokens)
            inputs.append(f"{lang} {ip}")
            outputs.append(op)

        return {"inputs": inputs, "outputs": outputs}

    mask_tokens = sorted(tokenizer.additional_special_tokens)
    ds = load_codesearchnet_data("all")
    original_columns = list(ds["train"].features.keys())
    return ds.map(
        masked_node_pred,
        batched=True,
        fn_kwargs={"mask_tokens": mask_tokens},
        remove_columns=original_columns,
        num_proc=NPROC,
    )


# TASK: CodeSearchNet Masked sub-tree prediction
def codesearchnet_masked_subtree_pred(tokenizer, maxlen, mask_rate):
    """Build the masked sub-tree prediction task from the CodeSearchNet data.

    Each code sample is parsed and serialized with ``parsecode`` and tokenized
    with ``tokenizer``. A random set of non-overlapping sub-trees (token spans
    from a ``NODE_START_STR`` token to its matching ``NODE_END_STR`` token) is
    then replaced by sentinel tokens; the target lists each sentinel followed
    by the tokens it hides.

    Args:
        tokenizer: Tokenizer used to split the serialized tree; it must provide
            ``__call__``, ``bos_token_id``, ``convert_ids_to_tokens``,
            ``convert_tokens_to_string`` and ``additional_special_tokens``
            (the sorted list of the latter is the sentinel pool).
        maxlen: Only the first ``maxlen`` tokens are scanned for sub-trees.
        mask_rate: Fraction of a sample's token count that may be masked.

    Returns:
        The mapped dataset with the columns ``"inputs"`` and ``"outputs"``.
    """

    def get_subtrees(tokens, mask_len, maxlen):
        # Return (start, end_exclusive, length) for every complete sub-tree
        # found in the first `maxlen` tokens that is at most `mask_len` long.
        subtrees = []
        tree_open_idxs = []  # stack of indices of currently open nodes

        for idx in range(min(len(tokens), maxlen)):
            curr_token = tokens[idx].strip()

            if curr_token.startswith(NODE_START_STR):
                tree_open_idxs.append(idx)

            elif curr_token.endswith(NODE_END_STR):
                # A close marker without a pending open marker cannot delimit a
                # sub-tree; skip it instead of failing on an empty stack.
                if not tree_open_idxs:
                    continue

                last_open_node = tree_open_idxs.pop()
                treelen = idx - last_open_node + 1

                # sub-trees longer than the mask budget can never be sampled
                if treelen > mask_len:
                    continue

                subtrees.append((last_open_node, idx + 1, treelen))

        return subtrees

    def spans_overlap(x, y):
        # True when the [start, end] ranges of x and y intersect. Unlike testing
        # only whether an endpoint of y falls inside x, this also catches a
        # span y that completely encloses x. Touching spans count as
        # overlapping, as before.
        return x[0] <= y[1] and y[0] <= x[1]

    def create_ip_op_pair(tokens, subtrees, mask_tokens, mask_len):
        # Randomly pick non-overlapping sub-trees until the mask budget
        # (`mask_len` tokens) is used up.
        candidates = [x for x in subtrees if x[2] <= mask_len]
        random.shuffle(candidates)
        budget = mask_len
        sampled_trees = []

        for tree_ds in candidates:
            if budget <= 0:
                break
            if tree_ds[2] <= budget and not any(spans_overlap(x, tree_ds) for x in sampled_trees):
                sampled_trees.append(tree_ds)
                budget -= tree_ds[2]

        # Limit number of sampled trees to maximum number of mask tokens
        sampled_trees = sampled_trees[: len(mask_tokens)]
        sampled_trees = sorted(sampled_trees, key=lambda x: x[0])

        new_ip, new_op = [], []
        last_idx = 0
        for mask_tok, (startidx, endidx, _treelen) in zip(mask_tokens, sampled_trees):
            new_ip.extend(tokens[last_idx:startidx])
            new_ip.append(mask_tok)
            new_op.append(mask_tok)
            new_op.extend(tokens[startidx:endidx])
            last_idx = endidx

        # NOTE(review): tokens after the last masked sub-tree (tokens[last_idx:])
        # are not appended to the input. This mirrors the existing behaviour;
        # confirm whether dropping the tail is intended.
        new_ip = tokenizer.convert_tokens_to_string(new_ip)
        new_op = tokenizer.convert_tokens_to_string(new_op)
        return new_ip, new_op

    def masked_subtree(examples, tokenizer, maxlen, mask_tokens, mask_rate):
        # Batched map function; rows whose code cannot be parsed are skipped.
        example_ip, example_op = [], []
        rows = zip(examples[LANG_KEY], examples[CODE_KEY], examples[RAW_DOCSTRING_KEY])

        for lang, code, rawdoc in rows:
            code = clean_doc_from_csn_code_samples(code, rawdoc, lang)
            treestr = parsecode(code, lang, add_cls=True)

            if treestr is None:
                continue

            treestr = " " + treestr.strip()

            token_ids = tokenizer(treestr)["input_ids"]
            # filter <bos> token
            if token_ids and token_ids[0] == tokenizer.bos_token_id:
                token_ids = token_ids[1:]

            tokens = tokenizer.convert_ids_to_tokens(token_ids)
            mask_len = int(mask_rate * len(tokens))

            subtrees = get_subtrees(tokens, mask_len, maxlen)
            ip, op = create_ip_op_pair(tokens, subtrees, mask_tokens, mask_len)

            example_ip.append(f"{lang} {ip}")
            example_op.append(op)

        return {"inputs": example_ip, "outputs": example_op}

    mask_tokens = sorted(tokenizer.additional_special_tokens)
    ds = load_codesearchnet_data("all")
    cols = list(ds["train"].features.keys())
    ds = ds.map(
        masked_subtree,
        batched=True,
        remove_columns=cols,
        fn_kwargs={
            "tokenizer": tokenizer,
            "maxlen": maxlen,
            "mask_tokens": mask_tokens,
            "mask_rate": mask_rate,
        },
        num_proc=NPROC,
    )
    return ds


# TASK: CodeSearchNet Next Token Prediction
def codesearchnet_next_token_prediction():
    """Build the next-token prediction task from the CodeSearchNet data.

    Every sample whose code can be parsed yields the text
    ``"<doc> <docstring> <tree> <tree>"`` (or just ``"<tree> <tree>"`` when the
    sample has no docstring tokens), used as both input and output.

    Returns:
        The mapped dataset with the columns ``"inputs"`` and ``"outputs"``.
    """

    def next_token_prediction(examples):
        # Batched map function; rows whose code cannot be parsed are skipped.
        texts = []
        rows = zip(
            examples[DOC_KEY],
            examples[LANG_KEY],
            examples[CODE_KEY],
            examples[RAW_DOCSTRING_KEY],
        )

        for doc_tokens, lang, code, rawdoc in rows:
            stripped_code = clean_doc_from_csn_code_samples(code, rawdoc, lang)
            treestr = parsecode(stripped_code, lang, add_cls=True)
            if treestr is None:
                continue

            if len(doc_tokens) == 0:
                text = f"<tree> {treestr}"
            else:
                doc = " ".join(doc_tokens)
                text = f"<doc> {doc} <tree> {treestr}"
            texts.append(text)

        # Input and target are the same text; they are separate list objects.
        return {"inputs": texts, "outputs": list(texts)}

    ds = load_codesearchnet_data("all")
    original_columns = list(ds["train"].features.keys())
    return ds.map(next_token_prediction, batched=True, remove_columns=original_columns, num_proc=NPROC)


def get_task_datasets(tasklist, tokenizer, maxlen, mask_rate):
    """Build the dataset of every known task named in ``tasklist``.

    Tasks are built in a fixed order (the order of the table below), not in the
    order they appear in ``tasklist``; unknown names are ignored.

    Args:
        tasklist: Container of task names to build.
        tokenizer: Forwarded to the masking tasks.
        maxlen: Forwarded to the masked sub-tree task.
        mask_rate: Forwarded to the masked sub-tree task.

    Returns:
        List with one dataset per selected task.

    Raises:
        AssertionError: If no known task was selected.
    """
    # Builders are wrapped in lambdas so that each one runs only when selected.
    builders = (
        ("codesearchnet_nl2tree", lambda: codesearchnet_nl2tree()),
        ("codesearchnet_tree2nl", lambda: codesearchnet_tree2nl()),
        ("codesearchnet_maskednodepred", lambda: codesearchnet_masked_node_pred(tokenizer)),
        ("codesearchnet_subtreepred", lambda: codesearchnet_masked_subtree_pred(tokenizer, maxlen, mask_rate)),
        ("codesearchnet_nexttokenprediction", lambda: codesearchnet_next_token_prediction()),
    )

    task_datasets = []
    for task_name, build in builders:
        if task_name not in tasklist:
            continue
        print(f"loading task: {task_name}")
        task_datasets.append(build())

    if len(task_datasets) == 0:
        raise AssertionError("Datasets final set is empty")

    return task_datasets


def tokenize(examples, tokenizer, maxlen):
    """Tokenize a batch of ``"inputs"``/``"outputs"`` text pairs.

    Args:
        examples: Batch with the columns ``"inputs"`` and ``"outputs"``.
        tokenizer: Callable invoked with the inputs, ``text_target`` set to the
            outputs, ``truncation=True`` and ``max_length=maxlen``.
        maxlen: Maximum length passed to the tokenizer.

    Returns:
        The tokenizer result, extended with ``"input_length"`` and
        ``"output_length"`` (lengths of each ``"input_ids"`` / ``"labels"`` entry).
    """
    model_inputs = tokenizer(
        examples["inputs"],
        text_target=examples["outputs"],
        truncation=True,
        max_length=maxlen,
    )

    for length_key, ids_key in (("input_length", "input_ids"), ("output_length", "labels")):
        model_inputs[length_key] = list(map(len, model_inputs[ids_key]))

    return model_inputs


def get_dataset(
    tasklist,
    tokenizer,
    maxlen,
    mask_rate,
    nsamples_train=None,
    nsamples_val=None,
    nproc=None,
):
    """Build, interleave and tokenize the datasets of the requested tasks.

    Args:
        tasklist: Names of the tasks to build (see ``get_task_datasets``).
        tokenizer: Tokenizer used by the masking tasks and for tokenization.
        maxlen: Maximum tokenized length; also forwarded to the task builders.
        mask_rate: Masking rate forwarded to the task builders.
        nsamples_train: If a positive number, keep at most this many leading
            training samples.
        nsamples_val: If a positive number, keep at most this many leading
            validation samples (taken after the validation set is shuffled).
        nproc: If given, overwrites the module-level ``NPROC`` used by the
            task builders.

    Returns:
        Tuple ``(train_dataset, validation_dataset)`` of tokenized datasets.

    Raises:
        AssertionError: If ``tasklist`` selects no known task.
    """
    global NPROC
    if nproc is not None:
        NPROC = nproc

    task_datasets = get_task_datasets(tasklist, tokenizer, maxlen, mask_rate)
    ds_train = [ds["train"] for ds in task_datasets]
    ds_val = [ds["valid"] for ds in task_datasets]
    print("task loading complete")

    # interleave all tasks sequentially
    ds_final_train = datasets.interleave_datasets(ds_train, stopping_strategy="all_exhausted")
    ds_final_val = datasets.interleave_datasets(ds_val, stopping_strategy="all_exhausted").shuffle(42)
    print("task interleaving complete")

    # `select(range(n))` is used instead of `take(n)` to match
    # `get_training_dataset` and to work with `datasets` releases whose
    # `Dataset` has no `take`. The count is clamped so that asking for more
    # samples than exist keeps the whole dataset instead of raising.
    if nsamples_train is not None and nsamples_train > 0:
        ds_final_train = ds_final_train.select(range(min(nsamples_train, len(ds_final_train))))
    if nsamples_val is not None and nsamples_val > 0:
        ds_final_val = ds_final_val.select(range(min(nsamples_val, len(ds_final_val))))

    cols_train = list(ds_final_train.features.keys())
    cols_val = list(ds_final_val.features.keys())
    tokenize_kwargs = {"tokenizer": tokenizer, "maxlen": maxlen}

    # tokenize dataset
    print("tokenizing dataset")
    ds_final_train = ds_final_train.map(
        tokenize,
        batched=True,
        remove_columns=cols_train,
        fn_kwargs=tokenize_kwargs,
    )

    print("train dataset tokenized. tokenizing validation dataset")
    ds_final_val = ds_final_val.map(
        tokenize,
        batched=True,
        remove_columns=cols_val,
        fn_kwargs=tokenize_kwargs,
    )

    return ds_final_train, ds_final_val


def get_training_dataset(
    tasklist,
    datadir,
    nsamples_train=None,
    nsamples_val=None,
):
    """Load pre-built task datasets from disk and interleave them.

    Args:
        tasklist: Names of the tasks to load; each must be a key of ``DIRMAP``.
        datadir: Directory containing one sub-directory per task.
        nsamples_train: If a positive number, keep at most this many leading
            training samples.
        nsamples_val: If a positive number, keep at most this many leading
            validation samples.

    Returns:
        Tuple ``(train_dataset, validation_dataset)``.

    Raises:
        KeyError: If a task in ``tasklist`` has no entry in ``DIRMAP``.
    """
    # NOTE(review): "codesearchnet_nexttokenprediction" is accepted by
    # get_task_datasets but has no directory entry here - confirm whether one
    # should be added.
    DIRMAP = {
        "codesearchnet_nl2tree": "CSN-codesearchnet_nl2tree",
        "codesearchnet_tree2nl": "CSN-codesearchnet_tree2nl",
        "codesearchnet_maskednodepred": "CSN-codesearchnet_maskednodepred",
        "codesearchnet_subtreepred": "CSN-codesearchnet_subtreepred",
    }

    task_datasets = []
    for task in tasklist:
        if task not in DIRMAP:
            raise KeyError(f"Unsupported task {task!r}; expected one of {sorted(DIRMAP)}")

        taskpath = os.path.join(datadir, DIRMAP[task])
        print(f"Loading task: {task}. Datapaths: {taskpath}")
        task_datasets.append(datasets.load_from_disk(taskpath))

        print(f"Loading task: {task} completed")

    ds_train = [ds["train"] for ds in task_datasets]
    ds_val = [ds["valid"] for ds in task_datasets]
    print(f"task loading completed. Total {len(ds_train)} tasks loaded")

    # interleave all tasks sequentially
    print("Interleaving train dataset")
    ds_final_train = datasets.interleave_datasets(ds_train, stopping_strategy="all_exhausted")

    print("Interleaving validation dataset")
    ds_final_val = datasets.interleave_datasets(ds_val, stopping_strategy="all_exhausted")
    print("task interleaving complete")

    # Clamp the requested counts: selecting indices past the end of a dataset
    # raises, whereas "at most n samples" is what the caller asks for.
    if nsamples_train is not None and nsamples_train > 0:
        ds_final_train = ds_final_train.select(range(min(nsamples_train, len(ds_final_train))))
    if nsamples_val is not None and nsamples_val > 0:
        ds_final_val = ds_final_val.select(range(min(nsamples_val, len(ds_final_val))))

    return ds_final_train, ds_final_val


##------------------------------------------------##
# util functions


def parsecode(code, lang, add_cls):
    """Parse ``code`` and return its serialized tree string, or ``None``.

    The code is turned into a tree with ``parseutils.create_TS_tree``,
    converted with ``parseutils.create_custom_tree`` and flattened with
    ``parseutils.serialize`` using ``TRAVERSAL_TYPE``.

    Args:
        code: Source code to parse.
        lang: Language name of ``code``.
        add_cls: Forwarded unchanged to ``parseutils.create_TS_tree``.

    Returns:
        The serialized tree, or ``None`` if any step raised an exception
        (callers skip such samples).
    """
    try:
        root = parseutils.create_TS_tree(code, lang, add_cls)
        croot = parseutils.create_custom_tree(root, lang)
        return parseutils.serialize(croot, TRAVERSAL_TYPE)
    except Exception as err:
        # Best-effort by design: one unparsable sample must not abort a whole
        # dataset map. The failure is recorded at DEBUG level so that it can be
        # inspected instead of vanishing silently.
        import logging

        logging.getLogger(__name__).debug("parsecode failed for lang=%r: %r", lang, err)
        return None
